""" LangDT Policy (Actor) Implementation """
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import transformers
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from sb3_jax.common.policies import BasePolicy
from sb3_jax.common.norm_layers import BaseNormLayer
from sb3_jax.common.jax_layers import init_weights, BaseFeaturesExtractor, FlattenExtractor
from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.utils import get_dummy_obs, get_dummy_act
from sb3_jax.dt.gpt2 import GPT2Model, init_embed

from diffgro.utils.utils import print_b 


class TrajectoryModel(hk.Module):
    """ GPT-2 trajectory model over the token sequence (language, s1, a1, s2, a2, ...).

    One language token per batch element is prepended to the interleaved
    observation/action tokens. From the transformer output, the action tokens
    are decoded into next-observation predictions and the observation tokens
    are decoded into action predictions.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        hidden_size: int = 128,
        max_length: int = None,
        max_ep_length: int = 4096,
        squash_action: bool = False, 
        config: transformers.GPT2Config = None
    ):
        """
        Args:
            observation_space: used to derive the flattened observation dim.
            action_space: used to derive the action dim.
            hidden_size: width of every token embedding.
            max_length: context length (stored; not read inside this module).
            max_ep_length: vocabulary size of the timestep embedding table.
            squash_action: if True, action predictions are passed through tanh.
            config: GPT-2 config forwarded to ``GPT2Model``.
        """
        super().__init__()
        self.observation_dim = get_flattened_obs_dim(observation_space) 
        self.action_dim = get_act_dim(action_space)
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.max_ep_length = max_ep_length
        self.squash_action = squash_action
        self.config = config 

    def __call__(
        self,
        traj_dict: Dict[str, jnp.ndarray],
        languages: jnp.ndarray, # lang embed
        attention_mask: jnp.ndarray = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Run the model on a batch of trajectory windows.

        Args:
            traj_dict: dict with keys 'obs', 'act' and 't'. Batch size and
                sequence length are read from ``obs.shape[:2]``. Presumably
                'obs'/'act' are [batch, seq, dim] and 't' is [batch, seq]
                integer timesteps (TODO confirm against callers).
            languages: language (or language+skill) embedding. It is reshaped
                to [batch, 1, -1] so that it forms a single token.
            attention_mask: [batch, seq] mask, 1 where a step may be attended
                to and 0 where it may not (padding). Defaults to all ones.
            deterministic: forwarded to ``GPT2Model``.

        Returns:
            (observation_preds, action_preds). observation_preds has shape
            [batch, seq - 1, obs_dim] because the last step is dropped.
            action_preds has shape [batch, seq, act_dim] and is tanh-squashed
            when ``squash_action`` is set.
        """
        # [batch_size, seq_length, dim]
        observations, actions, timesteps = traj_dict['obs'], traj_dict['act'], traj_dict['t']
        batch_size, seq_length = observations.shape[0], observations.shape[1]
        # one language token per batch element: [batch_size, 1, lang_dim]
        languages = languages.reshape(batch_size, 1, -1)
        num_tokens = languages.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if the token can be attended to, 0 if not
            attention_mask = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
        
        # embed each modality with a different head
        # NOTE: Haiku derives parameter names from module creation order, so keep
        # the order of the hk.* modules below stable to stay compatible with saved params.
        observation_embed = hk.Linear(self.hidden_size, **init_weights())(observations)
        action_embed = hk.Linear(self.hidden_size, **init_weights())(actions)
        timestep_embed = hk.Embed(self.max_ep_length, self.hidden_size, **init_embed())(timesteps)  
        lang_embed = hk.Linear(self.hidden_size, **init_weights())(languages)

        # timestep embeddings are treated similar to positional embeddings
        observation_embed = observation_embed + timestep_embed
        action_embed = action_embed + timestep_embed

        # this makes the sequence look like (s_1, a_1, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = jnp.stack(
            (observation_embed, action_embed), axis=1
        ).transpose(0, 2, 1, 3).reshape(batch_size, 2*seq_length, self.hidden_size)
        # prepend the language token: (lang, s_1, a_1, s_2, a_2, ...)
        stacked_inputs = jnp.concatenate([lang_embed, stacked_inputs], axis=1)
        stacked_inputs = hk.LayerNorm(-1, create_scale=True, create_offset=True)(stacked_inputs)
        
        # to make the attention mask fit the stacked inputs, have to stack it as well;
        # the language token(s) are always attendable
        stacked_attention_mask = jnp.stack(
            (attention_mask, attention_mask), axis=1
        ).transpose(0, 2, 1).reshape(batch_size, 2*seq_length)
        stacked_attention_mask = jnp.concatenate([jnp.ones((batch_size, num_tokens), jnp.int32), stacked_attention_mask], axis=1) 

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = GPT2Model(self.config)(
            input_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
            deterministic=deterministic
        )
        x = transformer_outputs["last_hidden_state"]

        # drop the language token(s) and un-interleave the rest so that the second
        # dimension indexes the token type: traj_out[:, 0, t] is the output at the
        # observation token s_t and traj_out[:, 1, t] the output at the action token a_t
        traj_out = x[:,num_tokens:,:].reshape(batch_size, seq_length, 2, self.hidden_size).transpose(0, 2, 1, 3)  # [b, 2, seq, d]

        # predict next state from the action token (given s_t and a_t); the final
        # step is dropped, presumably because its target lies outside the window
        observation_preds = hk.Linear(self.observation_dim, **init_weights())(traj_out[:,1])[:,:-1,:] 
        # predict the action from the observation token (given s_t)
        action_preds = hk.Linear(self.action_dim, **init_weights())(traj_out[:,0])
        if self.squash_action:
            action_preds = jax.nn.tanh(action_preds)
        return (observation_preds, action_preds)


class Actor(BasePolicy):
    """Language-conditioned Decision-Transformer actor.

    Wraps a Haiku-transformed ``TrajectoryModel`` with a GPT-2 backbone and
    owns its parameters and state. In the ``'long'`` domain, the skill embedding
    is concatenated to the language embedding before it is fed to the model.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None, # not used
        activation_fn: str = 'gelu_new',
        domain: str = 'short',
        skill_dim: int = 512,
        # gpt config
        max_length: int = None, # horizon
        max_ep_length: int = None,
        hidden_size: int = 128,
        n_layer: int = 2,
        n_head: int = 4,
        n_inner: int = None,
        n_positions: int = 1024,
        resid_pdrop: float = None,
        attn_pdrop: float = None,
        seed: int = 1,
    ):
        super(Actor, self).__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        self.net_arch = net_arch  # stored only; the model does not use it
        self.activation_fn = activation_fn 

        self.domain = domain
        print_b(f"Setting domain as {self.domain}")
        
        # gpt2 config
        self.skill_dim = skill_dim
        self.max_length = max_length
        self.max_ep_length = max_ep_length
        self.hidden_size = hidden_size
        self.n_layer = n_layer
        # plain int, forwarded as GPT2Config(n_head=...); must not be wrapped in a tuple
        self.n_head = n_head
        self.n_inner = n_inner
        self.n_positions = n_positions 
        self.resid_pdrop = resid_pdrop
        self.attn_pdrop = attn_pdrop
        
        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()
        return data

    def _build_act(self,) -> hk.Module:
        """Create the ``TrajectoryModel``. Must be called inside a Haiku transform."""
        config = transformers.GPT2Config(
            vocab_size=1, # doesn't matter -- we don't use the vocab
            hidden_size=self.hidden_size, 
            n_layer=self.n_layer,
            n_head=self.n_head,
            n_inner=self.n_inner,
            activation_function=self.activation_fn,
            n_positions=self.n_positions,
            resid_pdrop=self.resid_pdrop,
            attn_pdrop=self.attn_pdrop,
        )
        return TrajectoryModel(
            observation_space=self.observation_space, 
            action_space=self.action_space,
            hidden_size=self.hidden_size,
            max_length=self.max_length,
            max_ep_length=self.max_ep_length,
            squash_action=True,
            config=config,
        ) 

    def _build(self) -> None:
        """Transform the model and initialize ``self.params`` / ``self.state``.

        Initialization runs the model once on dummy inputs of length
        ``max_length``, so ``max_length`` must be set. The apply function is
        stored as ``self.pi``.
        """
        dummy_obs, dummy_act = get_dummy_obs(self.observation_space), get_dummy_act(self.action_space)
        dummy_obs_stack = jnp.repeat(dummy_obs, self.max_length, axis=0).reshape(1, self.max_length, -1) # stacked observation
        dummy_act_stack = jnp.repeat(dummy_act, self.max_length, axis=0).reshape(1, self.max_length, -1) # stacked action
        dummy_t = jnp.arange(0, self.max_length).reshape(1, -1)
        dummy_lang = jax.random.normal(next(self.rng), shape=(1, self.skill_dim)) 
        dummy_skill = jax.random.normal(next(self.rng), shape=(1, self.skill_dim)) 
        dummy_mask = jnp.ones((1, self.max_length), jnp.int32)

        def fn_act(traj_dict: Dict[str, jax.Array], lang: jax.Array, mask: jax.Array, deterministic: bool):
            act = self._build_act()
            return act(traj_dict, lang, mask, deterministic)
        init_fn, self.pi = hk.transform_with_state(fn_act)
        traj_dict = {"obs": dummy_obs_stack, "act": dummy_act_stack, "t": dummy_t} 
        if self.domain == 'long':
            # long domain conditions on [language, skill]; input width must match at init
            dummy_lang = jnp.concatenate([dummy_lang, dummy_skill], axis=-1)
        self.params, self.state = init_fn(next(self.rng), traj_dict, dummy_lang, dummy_mask, deterministic=False)
    
    # `self` (arg 0) and `deterministic` (arg 5) are static for jit
    @partial(jax.jit, static_argnums=(0,5))
    def _pi(
        self,
        traj_dict: Dict[str, jnp.ndarray],
        mask: jax.Array,
        lang: jax.Array,
        skill: jax.Array,
        deterministic: bool,
        params: hk.Params,
        state: hk.Params,
        rng=None,
    ) -> Tuple[Tuple[jax.Array], jax.Array]:
        """Jitted forward pass.

        Returns ``((obs_pred, act_pred), new_state)``. ``skill`` is only used,
        and so only required, in the ``'long'`` domain.
        """
        if self.domain == 'long':
            lang = jnp.concatenate([lang, skill], axis=-1) 
        return self.pi(params, state, rng, traj_dict, lang, mask, deterministic)

    def _predict(
        self,
        obs: jax.Array,
        act: jax.Array,
        t: jax.Array,
        mask: jax.Array,
        lang: jax.Array,
        skill: jax.Array = None,
        deterministic: bool = False,
    ) -> Tuple[jax.Array]:
        """Predict next observations and actions with the current params.

        Consumes one key from ``self.rng`` per call. The updated Haiku state
        returned by the model is discarded.
        """
        traj_dict = {"obs": obs, "act": act, "t": t}
        (obs_pred, act_pred), _ = self._pi(
            traj_dict, mask, lang, skill, deterministic, self.params, self.state, next(self.rng))
        return obs_pred, act_pred

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Load params and state from a dict with keys 'pi_params' and 'pi_state'."""
        print_b("[lang_dt/actor]: loading params")
        self.params = params["pi_params"]
        self.state = params["pi_state"]


class LangDTPlannerPolicy(BasePolicy):
    """Policy holding a language-conditioned DT ``Actor`` and its optimizer.

    Observations go through the policy's ``preprocess`` step before they reach
    the Actor. ``domain`` selects whether a skill embedding is concatenated
    to the language embedding (``'long'``) or not (``'short'``).
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None, # not used
        activation_fn: str = 'gelu_new',
        domain: str = 'short',
        skill_dim: int = 512,
        # gpt configs
        max_length: int = None,
        max_ep_length: int = None,
        hidden_size: int = 128,
        n_layer: int = None,
        n_head: int = None,
        n_inner: int = None,
        n_positions: int = 1024,
        resid_pdrop: float = None,
        attn_pdrop: float = None,
        # others
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable = optax.adamw,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        normalization_class: Type[BaseNormLayer] = None,
        normalization_kwargs: Optional[Dict[str, Any]] = None,
        seed: int = 1,
    ):
        super(LangDTPlannerPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalization_class=normalization_class,
            normalization_kwargs=normalization_kwargs,
            squash_output=squash_output,
            seed=seed,
        )
        
        self.skill_dim = skill_dim
        self.max_length = max_length

        self.domain = domain
        # explicit check: an `assert` would be stripped under `python -O`
        if self.domain not in ('short', 'long'):
            raise ValueError(f"Domain should be either short or long, got {self.domain!r}")

        self.activation_fn = activation_fn
        self.act_kwargs = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "activation_fn": self.activation_fn,
            "net_arch": net_arch,
            "domain": domain,
            "skill_dim": skill_dim,
            "max_length": max_length,
            "max_ep_length": max_ep_length,
            "hidden_size": hidden_size,
            "n_layer": n_layer,
            "n_head": n_head,
            "n_inner": n_inner,
            "n_positions": n_positions,
            "resid_pdrop": resid_pdrop,
            "attn_pdrop": attn_pdrop,
            "seed": seed,
        }

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        # NOTE(review): lr_schedule, domain, skill_dim, max_length and the GPT
        # hyperparameters (kept in self.act_kwargs) are not included here; confirm
        # that whatever reconstructs the policy from this dict supplies them.
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                observation_space=self.observation_space,
                action_space=self.action_space, 
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                normalization_class=self.normalization_class,
                normalization_kwargs=self.normalization_kwargs,
            )
        )
        return data
    
    def _build(self, lr_schedule: Schedule) -> None:
        """Create the optional normalization layer, the Actor and its optimizer state."""
        if self.normalization_class is not None:
            self.normalization_layer = self.normalization_class(self.observation_space.shape, **self.normalization_kwargs)
        
        self.act = self.make_act()
        self.act.optim = self.optimizer_class(learning_rate=lr_schedule, **self.optimizer_kwargs)
        self.act.optim_state = self.act.optim.init(self.act.params)

    def make_act(self) -> Actor:
        return Actor(**self.act_kwargs)

    def _predict(
        self,
        obs: jax.Array,
        act: jax.Array,
        t: jax.Array,
        mask: jax.Array,
        lang: jax.Array,
        skill: jax.Array = None,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]: 
        """Preprocess observations and return the Actor's ``(obs_pred, act_pred)``.

        Observations are flattened to ``[-1, obs_dim]`` for ``preprocess`` and
        then reshaped back to ``[batch, -1, obs_dim]``.
        """
        batch_size = obs.shape[0]
        obs = self.preprocess(obs.reshape(-1, self.act.obs_dim), training=False)
        obs = obs.reshape(batch_size, -1, self.act.obs_dim)
        obs_pred, act_pred = self.act._predict(obs, act, t, mask, lang, skill, deterministic)
        return obs_pred, act_pred

    @partial(jax.jit, static_argnums=0)
    def _preprocess(
        self, 
        observations: jnp.ndarray,
        actions: jnp.ndarray, 
        timesteps: jnp.ndarray, 
        attention_mask: jnp.ndarray, 
    ) -> Tuple[jax.Array]:
        """Shape a single trajectory into a model input window.

        Inputs are reshaped to a batch of one. If ``max_length`` is set, the
        last ``max_length`` steps are kept and the window is left-padded with
        zeros up to ``max_length``. A fresh attention mask is built that is 0
        on padding and 1 on real steps. The ``attention_mask`` argument is not
        read. When ``max_length`` is None, ``None`` is returned as the mask.

        Returns:
            (observations, actions, timesteps, attention_mask)
        """
        observation_dim, action_dim = observations.shape[-1], actions.shape[-1]
        
        observations = observations.reshape(1, -1, observation_dim)
        actions = actions.reshape(1, -1, action_dim)
        timesteps = timesteps.reshape(1, -1) 

        if self.max_length is not None:
            # keep only the most recent max_length steps
            observations = observations[:,-self.max_length:]
            actions = actions[:,-self.max_length:]
            timesteps = timesteps[:,-self.max_length:]

            # left-pad all tokens to the sequence length
            attention_mask = jnp.concatenate([jnp.zeros(self.max_length-observations.shape[1]), jnp.ones(observations.shape[1])], dtype=jnp.float32).reshape(1, -1)
            observations = jnp.concatenate(
                [jnp.zeros((observations.shape[0], self.max_length-observations.shape[1], observation_dim)), observations], axis=1, dtype=jnp.float32)
            actions = jnp.concatenate(
                [jnp.zeros((actions.shape[0], self.max_length-actions.shape[1], action_dim)), actions], axis=1, dtype=jnp.float32)
            timesteps = jnp.concatenate(
                [jnp.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1])), timesteps], axis=1, dtype=jnp.int32)
        else:
            attention_mask = None
        
        return observations, actions, timesteps, attention_mask
